/*
 * Copyright (C) 2018-2020. Huawei Technologies Co., Ltd. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.prestosql.query;

import com.google.common.cache.Cache;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.connector.CatalogName;
import io.prestosql.connector.informationschema.InformationSchemaTransactionHandle;
import io.prestosql.connector.system.GlobalSystemTransactionHandle;
import io.prestosql.connector.system.SystemTransactionHandle;
import io.prestosql.cost.CostCalculator;
import io.prestosql.cost.StatsCalculator;
import io.prestosql.dynamicfilter.DynamicFilterService;
import io.prestosql.execution.LocationFactory;
import io.prestosql.execution.NodeTaskMap;
import io.prestosql.execution.QueryPreparer;
import io.prestosql.execution.QueryStateMachine;
import io.prestosql.execution.RemoteTaskFactory;
import io.prestosql.execution.SqlQueryExecution;
import io.prestosql.execution.scheduler.ExecutionPolicy;
import io.prestosql.execution.scheduler.NodeScheduler;
import io.prestosql.execution.scheduler.SplitSchedulerStats;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.failuredetector.FailureDetector;
import io.prestosql.heuristicindex.HeuristicIndexerManager;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.TableHandle;
import io.prestosql.operator.ReuseExchangeOperator;
import io.prestosql.security.AccessControl;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ColumnMetadata;
import io.prestosql.spi.connector.ConnectorTransactionHandle;
import io.prestosql.spi.connector.Constraint;
import io.prestosql.spi.session.PropertyMetadata;
import io.prestosql.spi.statistics.TableStatistics;
import io.prestosql.spi.type.Type;
import io.prestosql.split.SplitManager;
import io.prestosql.sql.analyzer.Analysis;
import io.prestosql.sql.analyzer.QueryExplainer;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.LogicalPlanner;
import io.prestosql.sql.planner.NodePartitioningManager;
import io.prestosql.sql.planner.Partitioning;
import io.prestosql.sql.planner.PartitioningHandle;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.Plan;
import io.prestosql.sql.planner.PlanFragmenter;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.iterative.IterativeOptimizer;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.BeginTableWrite;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.CreateIndex;
import io.prestosql.sql.tree.CreateTable;
import io.prestosql.sql.tree.CreateTableAsSelect;
import io.prestosql.sql.tree.CurrentPath;
import io.prestosql.sql.tree.CurrentTime;
import io.prestosql.sql.tree.CurrentUser;
import io.prestosql.sql.tree.DefaultTraversalVisitor;
import io.prestosql.sql.tree.Query;
import io.prestosql.sql.tree.Statement;
import io.prestosql.statestore.StateStoreProvider;
import io.prestosql.transaction.TransactionId;
import io.prestosql.utils.OptimizerUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;

import static io.prestosql.SystemSessionProperties.isExecutionPlanCacheEnabled;

/**
 * A {@link SqlQueryExecution} that can reuse logical plans built for earlier executions of the same query.
 * <p>
 * Plans are stored in an optional {@link Cache} keyed by {@code SqlQueryExecutionCacheKeyGenerator.buildKey}, which is
 * given the query, the referenced tables, the enabled optimizers/rules, the column types, the session time zone and the
 * session properties. A cached plan is reused only when it was recorded for the same statement, time zone and user and
 * the table statistics it was built with are unchanged; its table scans and exchanges are then rewritten to reference
 * the current query's table and transaction handles.
 */
public class CachedSqlQueryExecution
        extends SqlQueryExecution
{
    private final Optional<Cache<Integer, CachedSqlQueryExecutionPlan>> cache; // cache key is generated by SqlQueryExecutionCacheKeyGenerator
    private final BeginTableWrite beginTableWrite;

    public CachedSqlQueryExecution(QueryPreparer.PreparedQuery preparedQuery, QueryStateMachine stateMachine,
                                   String slug, Metadata metadata, AccessControl accessControl, SqlParser sqlParser, SplitManager splitManager,
                                   NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler,
                                   List<PlanOptimizer> planOptimizers, PlanFragmenter planFragmenter, RemoteTaskFactory remoteTaskFactory,
                                   LocationFactory locationFactory, int scheduleSplitBatchSize, ExecutorService queryExecutor,
                                   ScheduledExecutorService schedulerExecutor, FailureDetector failureDetector, NodeTaskMap nodeTaskMap,
                                   QueryExplainer queryExplainer, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats,
                                   StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector,
                                   DynamicFilterService dynamicFilterService, Optional<Cache<Integer, CachedSqlQueryExecutionPlan>> cache,
                                   HeuristicIndexerManager heuristicIndexerManager, StateStoreProvider stateStoreProvider)
    {
        super(preparedQuery, stateMachine, slug, metadata, accessControl, sqlParser, splitManager,
                nodePartitioningManager, nodeScheduler, planOptimizers, planFragmenter, remoteTaskFactory, locationFactory,
                scheduleSplitBatchSize, queryExecutor, schedulerExecutor, failureDetector, nodeTaskMap, queryExplainer,
                executionPolicy, schedulerStats, statsCalculator, costCalculator, warningCollector, dynamicFilterService, heuristicIndexerManager, stateStoreProvider);
        this.cache = cache;
        this.beginTableWrite = new BeginTableWrite(metadata);
    }

    /**
     * Creates the logical plan for {@code analysis}, reusing a cached plan when the query qualifies.
     * <p>
     * A query is cacheable only if all of the following hold:
     * <ol>
     *     <li>a plan cache is configured ({@code cache} is present);</li>
     *     <li>caching is enabled in the session;</li>
     *     <li>the query has no parameters;</li>
     *     <li>the original statement is not a {@code CreateIndex};</li>
     *     <li>the statement is a {@link Query} without CurrentPath/CurrentTime/CurrentUser;</li>
     *     <li>every referenced table reports plan-cache support and its metadata and statistics can be read.</li>
     * </ol>
     * Non-cacheable queries are planned by {@link SqlQueryExecution#createPlan}. For cacheable queries the
     * {@link BeginTableWrite} optimizer is applied to the resulting root as the final step.
     */
    @Override
    protected Plan createPlan(Analysis analysis, Session session, List<PlanOptimizer> planOptimizers,
                              PlanNodeIdAllocator idAllocator, Metadata metadata, TypeAnalyzer typeAnalyzer, StatsCalculator statsCalculator,
                              CostCalculator costCalculator, WarningCollector warningCollector)
    {
        Statement statement = analysis.getStatement();

        // Session properties which may affect the resulting execution plan (property name -> property value)
        Map<String, Object> systemSessionProperties = collectSystemSessionProperties(session);

        // if the original statement before rewriting is CreateIndex, set session to let connector know that pageMetadata should be enabled
        boolean isCreateIndex = analysis.getOriginalStatement() instanceof CreateIndex;
        if (isCreateIndex) {
            session.setPageMetadataEnabled(true);
        }

        // Filled in by validateAndExtractTableAndColumns: fully qualified table names, their statistics, and
        // "table.column" -> type, which makes column type changes between queries easy to detect
        List<String> tableNames = new ArrayList<>();
        Map<String, TableStatistics> tableStatistics = new HashMap<>();
        Map<String, Type> columnTypes = new HashMap<>();

        // Cheap checks run first so table metadata and statistics are only read when the query could be cached.
        // TODO: remove requirement for empty params and implement parameter rewrite
        boolean cacheable = this.cache.isPresent()
                && isExecutionPlanCacheEnabled(session)
                && analysis.getParameters().isEmpty()
                && !isCreateIndex // create index should not be cached
                && isCacheable(statement)
                && validateAndExtractTableAndColumns(analysis, metadata, session, tableNames, tableStatistics, columnTypes)
                && !tableNames.isEmpty();
        if (!cacheable) {
            return super.createPlan(analysis, session, planOptimizers, idAllocator, metadata, typeAnalyzer,
                    statsCalculator, costCalculator, warningCollector);
        }

        List<String> optimizers = collectEnabledOptimizers(planOptimizers, session);
        addCatalogSessionProperties(session, tableNames, systemSessionProperties);

        // TODO: Traverse the statement to build the key then combine tables/optimizers.. etc
        // The cast is safe: isCacheable only accepts Query statements
        int key = SqlQueryExecutionCacheKeyGenerator.buildKey((Query) statement, tableNames, optimizers, columnTypes, session.getTimeZoneKey(), systemSessionProperties);
        CachedSqlQueryExecutionPlan cachedPlan = this.cache.get().getIfPresent(key);

        Optional<PlanNode> reusedRoot = Optional.empty();
        if (isApplicable(cachedPlan, statement, session)) {
            reusedRoot = rewriteCachedPlan(cachedPlan, tableStatistics, session, analysis);
            if (!reusedRoot.isPresent()) {
                // Cached plan is outdated: invalidate the cache before building a new plan
                this.cache.get().invalidateAll();
            }
        }

        Plan plan;
        PlanNode root;
        if (reusedRoot.isPresent()) {
            // reusedRoot is only present when isApplicable returned true, so cachedPlan is non-null here
            plan = cachedPlan.getPlan();
            root = reusedRoot.get();
        }
        else {
            HetuLogicalPlanner logicalPlanner = new HetuLogicalPlanner(session, planOptimizers, idAllocator,
                    metadata, typeAnalyzer, statsCalculator, costCalculator, warningCollector);
            plan = createAndCachePlan(key, logicalPlanner, statement, tableNames, tableStatistics, optimizers, analysis, columnTypes, systemSessionProperties);
            root = plan.getRoot();
        }

        // BeginTableWrite optimizer must be run at the end as the last optimization
        // due to a hack Hetu community added which also serves to update
        // metadata in the nodes
        root = this.beginTableWrite.optimize(root, session, null, null, null, null);
        return update(plan, root);
    }

    /**
     * Returns every system session property name mapped to its value in {@code session}.
     */
    private static Map<String, Object> collectSystemSessionProperties(Session session)
    {
        Map<String, Object> properties = new HashMap<>();
        for (PropertyMetadata<?> property : new SystemSessionProperties().getSessionProperties()) {
            properties.put(property.getName(), session.getSystemProperty(property.getName(), property.getJavaType()));
        }
        return properties;
    }

    /**
     * Returns the simple class names of the optimizers enabled for {@code session}. For an
     * {@link IterativeOptimizer}, the names of its enabled rules are listed instead of the optimizer itself.
     */
    private static List<String> collectEnabledOptimizers(List<PlanOptimizer> planOptimizers, Session session)
    {
        List<String> optimizers = new ArrayList<>();
        for (PlanOptimizer planOptimizer : planOptimizers) {
            if (planOptimizer instanceof IterativeOptimizer) {
                for (Rule<?> rule : ((IterativeOptimizer) planOptimizer).getRules()) {
                    if (OptimizerUtils.isEnabledRule(rule, session)) {
                        optimizers.add(rule.getClass().getSimpleName());
                    }
                }
            }
            else if (OptimizerUtils.isEnabledLegacy(planOptimizer, session)) {
                optimizers.add(planOptimizer.getClass().getSimpleName());
            }
        }
        return optimizers;
    }

    /**
     * Adds the connector session properties of every catalog referenced by {@code tableNames} to
     * {@code properties}, keyed as "catalog.propertyName".
     */
    private static void addCatalogSessionProperties(Session session, List<String> tableNames, Map<String, Object> properties)
    {
        Set<String> catalogs = tableNames.stream()
                .map(CachedSqlQueryExecution::catalogOf)
                .collect(Collectors.toSet());
        for (String catalog : catalogs) {
            for (Map.Entry<String, String> property : session.getConnectorProperties(new CatalogName(catalog)).entrySet()) {
                properties.put(catalog + "." + property.getKey(), property.getValue());
            }
        }
    }

    /**
     * Returns the part of a fully qualified table name before the first '.', or the whole name if it has no '.'.
     */
    private static String catalogOf(String fullyQualifiedTableName)
    {
        int dot = fullyQualifiedTableName.indexOf('.');
        return dot < 0 ? fullyQualifiedTableName : fullyQualifiedTableName.substring(0, dot);
    }

    /**
     * Returns whether {@code cachedPlan} may be reused for {@code statement} in {@code session}.
     * <p>
     * To handle the chance of cache key collision, the statement, time zone and user recorded with the cached plan are
     * compared with the current ones. The time zone must match in order to preserve correctness for queries containing
     * functions that rely on system time. A transaction ID is required because the reused plan is rewritten to it.
     */
    private static boolean isApplicable(CachedSqlQueryExecutionPlan cachedPlan, Statement statement, Session session)
    {
        // TODO: traverse the statement and accept partial match
        return cachedPlan != null
                && cachedPlan.getPlan() != null
                && cachedPlan.getTimeZoneKey().equals(session.getTimeZoneKey())
                && cachedPlan.getStatement().equals(statement)
                && session.getTransactionId().isPresent()
                && cachedPlan.getIdentity().getUser().equals(session.getIdentity().getUser());
    }

    /**
     * Rewrites the root of {@code cachedPlan} to use the current query's table and transaction handles.
     *
     * @return the rewritten root, or empty if the cached plan is outdated: its table statistics differ from
     * {@code tableStatistics}, or one of its handles cannot be mapped to the current query
     */
    private static Optional<PlanNode> rewriteCachedPlan(
            CachedSqlQueryExecutionPlan cachedPlan,
            Map<String, TableStatistics> tableStatistics,
            Session session,
            Analysis analysis)
    {
        if (!cachedPlan.getTableStatistics().equals(tableStatistics)) {
            // TableStatistics have changed, therefore the cached plan may no longer be applicable
            return Optional.empty();
        }
        try {
            // TableScanNode may contain the old transaction id. The rewriter replaces each TableScanNode with a new
            // TableScanNode (and each affected ExchangeNode) carrying the new transaction id from the session.
            return Optional.of(SimplePlanRewriter.rewriteWith(new TableHandleRewriter(session, analysis), cachedPlan.getPlan().getRoot()));
        }
        catch (NoSuchElementException e) {
            // A handle referenced by the cached plan has no counterpart in the current query
            return Optional.empty();
        }
    }

    /**
     * Builds a new logical plan for {@code analysis} and stores it in the cache under {@code key} together with the
     * data needed to validate it on later lookups.
     */
    private Plan createAndCachePlan(
            int key,
            LogicalPlanner logicalPlanner,
            Statement statement,
            List<String> tableNames,
            Map<String, TableStatistics> tableStatistics,
            List<String> planOptimizers,
            Analysis analysis,
            Map<String, Type> columnTypes,
            Map<String, Object> systemSessionProperties)
    {
        Plan plan = logicalPlanner.plan(analysis);
        CachedSqlQueryExecutionPlan newCachedPlan = new CachedSqlQueryExecutionPlan(statement, tableNames, tableStatistics, planOptimizers, plan,
                analysis.getParameters(), columnTypes, getSession().getTimeZoneKey(), getSession().getIdentity(), systemSessionProperties);
        this.cache.get().put(key, newCachedPlan);
        return plan;
    }

    /**
     * Checks that every table referenced by {@code analysis} supports plan caching, collecting the data used for the
     * cache key and for the staleness check along the way.
     *
     * @param tables receives the fully qualified name of each table
     * @param tableStatistics receives fully qualified table name -> statistics, read with {@code Constraint.alwaysTrue()}
     * @param columns receives "fullyQualifiedTableName.columnName" -> column type
     * @return {@code false} if a table does not support plan caching or reading its metadata throws a
     * {@link PrestoException}; the output collections may then be partially filled
     */
    private static boolean validateAndExtractTableAndColumns(
            Analysis analysis,
            Metadata metadata,
            Session session,
            List<String> tables,
            Map<String, TableStatistics> tableStatistics,
            Map<String, Type> columns)
    {
        for (TableHandle tableHandle : analysis.getTables()) {
            try {
                // read metadata to see if plan caching is supported by the connector
                if (!metadata.isExecutionPlanCacheSupported(session, tableHandle)) {
                    return false;
                }
                String tableName = tableHandle.getFullyQualifiedName();
                tables.add(tableName);
                // TODO: Find a way to get constraints instead of reading all table statistics
                tableStatistics.put(tableName, metadata.getTableStatistics(session, tableHandle, Constraint.alwaysTrue()));

                Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, tableHandle);
                for (ColumnHandle columnHandle : columnHandles.values()) {
                    ColumnMetadata columnMetadata = metadata.getColumnMetadata(session, tableHandle, columnHandle);
                    columns.put(tableName + "." + columnHandle.getColumnName(), columnMetadata.getType());
                }
            }
            catch (PrestoException e) {
                // e.g. TableStatistics constraint -- cannot query more than 1000 hive partitions
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether {@code statement} is a {@link Query} that contains none of the constructs rejected by
     * {@link StatementChecker}.
     */
    private static boolean isCacheable(Statement statement)
    {
        // Skip cache when creating tables, hack for outdated metadata
        if (!(statement instanceof Query)) {
            return false;
        }

        try {
            // filter out create table statements and statements which contain CurrentX functions
            new StatementChecker().process(statement, null);
        }
        catch (UnsupportedOperationException e) {
            return false;
        }
        return true;
    }

    /**
     * Wraps {@code root} in a new {@link Plan} that reuses the types and stats/costs of {@code currentPlan}.
     */
    private static Plan update(Plan currentPlan, PlanNode root)
    {
        // Rebuild Plan object to get a new plan ID
        return new Plan(root, currentPlan.getTypes(), currentPlan.getStatsAndCosts()); // TODO: need to update Types for parameter rewrite
    }

    /**
     * Traverses a statement and throws {@link UnsupportedOperationException} on any construct that makes it
     * ineligible for plan caching.
     */
    private static class StatementChecker
            extends DefaultTraversalVisitor<Void, Void>
    {
        @Override
        protected Void visitCurrentPath(CurrentPath node, Void context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        protected Void visitCurrentTime(CurrentTime node, Void context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        protected Void visitCurrentUser(CurrentUser node, Void context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        protected Void visitCreateTable(CreateTable node, Void context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        protected Void visitCreateTableAsSelect(CreateTableAsSelect node, Void context)
        {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Rewrites a cached plan so that its table scans and transaction-bound exchanges reference the current query's
     * table handles and transaction. Throws {@link NoSuchElementException} when a cached handle cannot be mapped,
     * which callers treat as "cached plan outdated".
     */
    private static class TableHandleRewriter
            extends SimplePlanRewriter<Void>
    {
        private final TransactionId transactionId;
        // Fully qualified table name -> table handle from the current analysis
        private final Map<String, TableHandle> tables;
        // Old ConnectorTransactionHandle from cached plan -> new ConnectorTransactionHandle from the current query
        private final Map<ConnectorTransactionHandle, ConnectorTransactionHandle> connectorTransactionHandleMap = new HashMap<>();

        /**
         * @param session must have a transaction ID present
         */
        TableHandleRewriter(Session session, Analysis analysis)
        {
            this.transactionId = session.getTransactionId().get();

            Map<String, TableHandle> tablesByName = new HashMap<>();
            for (TableHandle handle : analysis.getTables()) {
                tablesByName.put(handle.getFullyQualifiedName(), handle);
            }
            this.tables = tablesByName;
        }

        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
        {
            TableHandle cachedTableHandle = node.getTable();
            TableHandle newTableHandle = toNewTableHandle(cachedTableHandle, tables, transactionId);
            // Remember cached -> new transaction handle so exchanges above this scan can be rewritten too
            connectorTransactionHandleMap.put(cachedTableHandle.getTransaction(), newTableHandle.getTransaction());

            // Keeps the ID, output symbols, assignments, enforced constraint and predicate of the cached node; the
            // reuse-exchange arguments are reset to REUSE_STRATEGY_DEFAULT, 0, 0 exactly as before
            return new TableScanNode(node.getId(), newTableHandle, node.getOutputSymbols(), node.getAssignments(), node.getEnforcedConstraint(), node.getPredicate(), ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, 0, 0);
        }

        @Override
        public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
        {
            PartitioningScheme oldScheme = node.getPartitioningScheme();
            PartitioningHandle oldPartitioningHandle = oldScheme.getPartitioning().getHandle();
            if (!oldPartitioningHandle.getTransactionHandle().isPresent()) {
                return super.visitExchange(node, context);
            }

            // Rewrite the sources first so their transaction handles are recorded. context.rewrite (not
            // defaultRewrite) is required so a source that is itself a TableScanNode or ExchangeNode is visited.
            List<PlanNode> sources = node.getSources().stream()
                    .map(source -> context.rewrite(source))
                    .collect(Collectors.toList());

            ConnectorTransactionHandle newTransactionHandle = connectorTransactionHandleMap.get(oldPartitioningHandle.getTransactionHandle().get());
            if (newTransactionHandle == null) {
                // No table scan below supplied a replacement; signal that the cached plan cannot be reused
                throw new NoSuchElementException("No current transaction handle for exchange " + node.getId());
            }
            PartitioningHandle partitioningHandle = new PartitioningHandle(oldPartitioningHandle.getConnectorId(),
                    Optional.of(newTransactionHandle),
                    oldPartitioningHandle.getConnectorHandle());

            List<Symbol> columns = new ArrayList<>(oldScheme.getPartitioning().getColumns());
            Partitioning partitioning = Partitioning.create(partitioningHandle, columns);
            // Preserve every other property of the cached scheme, not just the output layout
            PartitioningScheme partitioningScheme = new PartitioningScheme(
                    partitioning,
                    oldScheme.getOutputLayout(),
                    oldScheme.getHashColumn(),
                    oldScheme.isReplicateNullsAndAny(),
                    oldScheme.getBucketToPartition());
            return new ExchangeNode(node.getId(),
                    node.getType(),
                    node.getScope(),
                    partitioningScheme,
                    sources,
                    node.getInputs(),
                    node.getOrderingScheme());
        }

        /**
         * Builds a table handle for the current query from a cached one: catalog from the cached handle, connector
         * handle derived via {@code createFrom} of the cached connector handle, transaction rewritten to
         * {@code transactionId}, and layout from the current handle.
         *
         * @throws NoSuchElementException if the current query does not reference the cached handle's table
         */
        private static TableHandle toNewTableHandle(TableHandle oldTableHandle, Map<String, TableHandle> tables, TransactionId transactionId)
        {
            // Look up old table handle in the current session
            TableHandle newTableHandle = tables.get(oldTableHandle.getFullyQualifiedName());
            if (newTableHandle == null) {
                throw new NoSuchElementException("Table " + oldTableHandle.getFullyQualifiedName() + " is not referenced by the current query");
            }

            // New connector transaction handle may not have the correct transaction ID so it is explicitly rewritten here
            return new TableHandle(oldTableHandle.getCatalogName(),
                    newTableHandle.getConnectorHandle().createFrom(oldTableHandle.getConnectorHandle()),
                    toNewConnectorTransactionHandle(newTableHandle.getTransaction(), transactionId), newTableHandle.getLayout());
        }

        /**
         * Returns {@code transactionHandle} rebound to {@code transactionId} for the system and information-schema
         * handle types; any other handle is returned unchanged.
         */
        private static ConnectorTransactionHandle toNewConnectorTransactionHandle(ConnectorTransactionHandle transactionHandle, TransactionId transactionId)
        {
            if (transactionHandle instanceof GlobalSystemTransactionHandle) {
                return new GlobalSystemTransactionHandle(transactionId);
            }
            if (transactionHandle instanceof SystemTransactionHandle) {
                // Rewrite the wrapped connector handle (previously this recursed on the same handle forever)
                ConnectorTransactionHandle wrapped = ((SystemTransactionHandle) transactionHandle).getConnectorTransactionHandle();
                return new SystemTransactionHandle(transactionId, toNewConnectorTransactionHandle(wrapped, transactionId));
            }
            if (transactionHandle instanceof InformationSchemaTransactionHandle) {
                return new InformationSchemaTransactionHandle(transactionId);
            }
            // By default this method returns the original transaction handle unless overwritten
            // Some connector transaction handles (such as hive) cannot be reused between queries
            return transactionHandle;
        }
    }
}
